#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Run the LDBC SNB interactive workload / benchmark.
The benchmark is executed with:
    * ldbc_driver -> workload executor
    * ldbc-snb-impls/snb-interactive-neo4j -> workload implementation
"""

import argparse
import os
import shutil
import subprocess
import tempfile
import time

# Directory that contains this script (symlinks resolved).
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Directory three levels above SCRIPT_DIR; the database launchers look for
# "build/memgraph" and "libs/neo4j" underneath it.
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir, os.pardir))


def wait_for_server(port, delay=1.0):
    """Block until a probe of 127.0.0.1:``port`` succeeds.

    Runs ``nc -z -w 1 127.0.0.1 <port>`` over and over, pausing half a second
    after every failed attempt, until the command exits with status 0. There
    is no upper bound on the number of attempts. Once the probe succeeds the
    function sleeps for ``delay`` additional seconds before returning.
    """
    probe = ("nc", "-z", "-w", "1", "127.0.0.1", str(port))
    while True:
        if subprocess.call(probe) == 0:
            break
        time.sleep(0.5)
    time.sleep(delay)


class Memgraph:
    """Launcher for a Memgraph server process used as the benchmark target.

    The binary is taken from ``BASE_DIR/build/memgraph`` and is pointed at the
    ``memgraph`` sub-directory of ``dataset`` as its data directory.
    """

    def __init__(self, dataset, port, num_workers):
        # Popen handle of the server; stays None until start() has spawned it.
        self.proc = None
        self.dataset = dataset
        # Kept as strings because they are passed verbatim as command-line arguments.
        self.port = str(port)
        self.num_workers = str(num_workers)

    def start(self):
        """Spawn the server and block until its Bolt port accepts connections.

        Raises:
            OSError: if the process cannot be spawned (e.g. binary is missing).
        """
        # find executable path
        binary = os.path.join(BASE_DIR, "build", "memgraph")

        # database args
        database_args = [
            binary,
            "--bolt-num-workers",
            self.num_workers,
            "--data-directory",
            os.path.join(self.dataset, "memgraph"),
            "--data-recovery-on-startup",
            "true",
            "--bolt-port",
            self.port,
        ]

        # database env -- passing `env` replaces the inherited environment, so
        # the child process sees only this single variable.
        env = {"MEMGRAPH_CONFIG": os.path.join(SCRIPT_DIR, "config", "memgraph.conf")}

        # start memgraph
        self.proc = subprocess.Popen(database_args, env=env)
        wait_for_server(self.port)

    def stop(self):
        """Terminate the server and wait for it to exit.

        Does nothing when the server was never spawned, so the method is safe
        to call from a ``finally`` block even if ``start()`` failed early.

        Raises:
            Exception: if the server exits with a non-zero exit code.
        """
        if self.proc is None:
            # start() never got as far as spawning the process.
            return
        self.proc.terminate()
        if self.proc.wait() != 0:
            raise Exception("Database exited with non-zero exit code!")


class Neo:
    """Launcher for a Neo4j server process used as the benchmark target.

    Every run gets a throw-away ``NEO4J_HOME`` created inside ``dataset``; it
    holds a copy of the ``neo4j`` data of the dataset, a copy of the
    configuration and a symlink to the Neo4j libraries under
    ``BASE_DIR/libs/neo4j``.
    """

    def __init__(self, dataset, port):
        # Popen handle of the server; stays None until start() has spawned it.
        self.proc = None
        self.dataset = dataset
        self.port = str(port)
        # The HTTP connector listens at a fixed offset from the Bolt port.
        self.http_port = str(int(port) + 7474)
        # Temporary NEO4J_HOME; stays None until start() has created it.
        self.home_dir = None

    def start(self):
        """Prepare the temporary home directory, spawn Neo4j and wait for it.

        Raises:
            Exception: if preparing the home directory or spawning the process
                fails; the original error is attached as ``__cause__``.
        """
        # create home directory
        self.home_dir = tempfile.mkdtemp(dir=self.dataset)

        neo4j_dir = os.path.join(BASE_DIR, "libs", "neo4j")

        try:
            os.symlink(os.path.join(neo4j_dir, "lib"), os.path.join(self.home_dir, "lib"))
            shutil.copytree(os.path.join(self.dataset, "neo4j"), os.path.join(self.home_dir, "data"))
            conf_dir = os.path.join(self.home_dir, "conf")
            conf_file = os.path.join(conf_dir, "neo4j.conf")
            os.mkdir(conf_dir)
            shutil.copyfile(os.path.join(SCRIPT_DIR, "config", "neo4j.conf"), conf_file)
            # Append the listen addresses to the copied configuration.
            with open(conf_file, "a") as f:
                f.write("\ndbms.connector.bolt.listen_address=:" + self.port + "\n")
                f.write("\ndbms.connector.http.listen_address=:" + self.http_port + "\n")

            # environment -- passing `env` replaces the inherited environment,
            # so the child process sees only this single variable.
            env = {"NEO4J_HOME": self.home_dir}

            self.proc = subprocess.Popen([os.path.join(neo4j_dir, "bin", "neo4j"), "console"], env=env, cwd=neo4j_dir)
        except Exception as exc:
            self._remove_home_dir()
            raise Exception("Couldn't run Neo4j!") from exc
        except BaseException:
            # KeyboardInterrupt / SystemExit: still clean up, but let them
            # propagate unchanged instead of disguising them as a failure.
            self._remove_home_dir()
            raise

        wait_for_server(self.http_port, 2.0)

    def stop(self):
        """Terminate the server (if running) and remove the home directory.

        Safe to call when ``start()`` failed or was never called. The home
        directory is removed even if terminating or waiting raises.

        Raises:
            Exception: if the server exits with a non-zero exit code.
        """
        ret = 0
        try:
            if self.proc is not None:
                self.proc.terminate()
                ret = self.proc.wait()
        finally:
            self._remove_home_dir()
        if ret != 0:
            raise Exception("Database exited with non-zero exit code!")

    def _remove_home_dir(self):
        """Delete the temporary home directory if it was created and still exists."""
        if self.home_dir is not None and os.path.exists(self.home_dir):
            shutil.rmtree(self.home_dir)


def parse_args():
    """Parse the benchmark's command-line options.

    The module docstring is used as the ``--help`` description and is printed
    without re-wrapping. Returns the ``argparse.Namespace`` produced from the
    process arguments.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    add_option = parser.add_argument

    # Dataset selection and database endpoint.
    add_option("--scale", help="Dataset scale to use for benchmarking.", type=int, default=1)
    add_option("--host", help="Database host.", default="127.0.0.1")
    add_option("--port", help="Database port.", default="7687")

    # Load-shaping options.
    add_option(
        "--time-compression-ratio",
        help=(
            "Compress/stretch durations between operation start times to "
            "increase/decrease benchmark load. E.g. 2.0 = run benchmark 2x "
            "slower, 0.1 = run benchmark 10x faster. Default is 0.001."
        ),
        type=float,
        default=0.001,
    )
    add_option(
        "--operation-count",
        help="Number of operations to generate during benchmark execution.",
        type=int,
        default=1000,
    )
    add_option(
        "--thread-count",
        help="Thread pool size to use for executing operation handlers.",
        type=int,
        default=8,
    )
    add_option(
        "--time-unit",
        help="Time unit to use for measuring performance metrics",
        choices=("nanoseconds", "microseconds", "milliseconds", "seconds", "minutes"),
        default="microseconds",
    )

    # Result naming, query mix and optional database management.
    add_option("--result-file-prefix", help="Result file name prefix", default="")
    add_option("--test-type", help="Test queries of type", choices=("reads", "updates"), default="reads")
    add_option("--run-db", help="Run the database before starting LDBC", choices=("memgraph", "neo4j"))
    add_option("--create-index", help="Create index in the running database.", action="store_true")

    return parser.parse_args()


# Jar of the snb-interactive-neo4j workload implementation (bundled with its
# dependencies, going by the file name), located in that project's "target"
# directory.
LDBC_INTERACTIVE_NEO4J = os.path.join(
    os.path.join(SCRIPT_DIR, "ldbc-snb-impls", "snb-interactive-neo4j", "target"),
    "snb-interactive-neo4j-1.0.0-jar-with-dependencies.jar",
)
# Default properties file of the LDBC driver.
LDBC_DEFAULT_PROPERTIES = os.path.join(
    os.path.join(SCRIPT_DIR, "ldbc_driver", "configuration"),
    "ldbc_driver_default.properties",
)


def create_index(port, database):
    """Run the ``index_creation`` helper against the database on ``port``.

    The helper is started with the ``ve3`` interpreter, with ``SCRIPT_DIR`` as
    its working directory, and is handed the ``indexCreation.neo4j`` file of
    the snb-interactive-neo4j project. The function sleeps one second after
    the helper has finished.

    Raises:
        subprocess.CalledProcessError: if the helper exits with a non-zero code.
    """
    index_script = os.path.join(
        SCRIPT_DIR, "ldbc-snb-impls", "snb-interactive-neo4j", "scripts", "indexCreation.neo4j"
    )
    command = [
        "ve3/bin/python3",
        "index_creation",
        "--port",
        port,
        "--database",
        database,
        index_script,
    ]
    subprocess.check_call(command, cwd=SCRIPT_DIR)
    time.sleep(1.0)


def main():
    """Run one LDBC SNB interactive benchmark and save a copy of its results.

    Steps:
        1. Parse the command line and locate the dataset for the chosen scale.
        2. Optionally start Memgraph or Neo4j (``--run-db``).
        3. Optionally create the indexes (``--create-index``).
        4. Run the LDBC driver as a Java subprocess.
        5. Copy the driver's ``LDBC-results.json`` into ``SCRIPT_DIR/results``
           under a name built from the prefix, database and scale.

    A database started by this function is always stopped again, even when a
    later step fails.

    Raises:
        Exception: if ``--create-index`` is given without ``--run-db``, or if
            ``--run-db`` is combined with a host other than 127.0.0.1.
        subprocess.CalledProcessError: if index creation or the LDBC driver
            exits with a non-zero code.
    """
    args = parse_args()
    dataset = os.path.join(SCRIPT_DIR, "datasets", "scale_" + str(args.scale))

    # Index creation must be told which database it talks to; that is only
    # known when this script starts the database itself. Fail early with a
    # readable message instead of crashing later on a missing value.
    if args.create_index and not args.run_db:
        raise Exception("--create-index requires --run-db, because the database type must be known!")

    db = None
    if args.run_db:
        if args.host != "127.0.0.1":
            raise Exception("Host parameter must point to localhost when this script starts the database!")
        # argparse restricts --run-db to exactly these two values.
        if args.run_db == "memgraph":
            db = Memgraph(dataset, args.port, args.thread_count)
        elif args.run_db == "neo4j":
            db = Neo(dataset, args.port)

    try:
        if db:
            db.start()
            # Give the freshly started database extra time before using it.
            time.sleep(10.0)

        if args.create_index:
            create_index(args.port, args.run_db)
            time.sleep(10.0)

        # Run LDBC driver. The classpath's first entry is relative to the
        # driver directory, which is used as the working directory below.
        cp = "target/jeeves-0.3-SNAPSHOT.jar:{}".format(LDBC_INTERACTIVE_NEO4J)
        updates_dir = os.path.join(dataset, "social_network")
        parameters_dir = os.path.join(dataset, "substitution_parameters")
        java_cmd = (
            "java",
            "-cp",
            cp,
            "com.ldbc.driver.Client",
            "-P",
            LDBC_DEFAULT_PROPERTIES,
            "-P",
            os.path.join(SCRIPT_DIR, "ldbc-snb-impls-{}.properties".format(args.test_type)),
            "-p",
            "ldbc.snb.interactive.updates_dir",
            updates_dir,
            "-p",
            "host",
            args.host,
            "-p",
            "port",
            args.port,
            "-db",
            "net.ellitron.ldbcsnbimpls.interactive.neo4j.Neo4jDb",
            "-p",
            "ldbc.snb.interactive.parameters_dir",
            parameters_dir,
            "--time_compression_ratio",
            str(args.time_compression_ratio),
            "--operation_count",
            str(args.operation_count),
            "--thread_count",
            str(args.thread_count),
            "--time_unit",
            args.time_unit.upper(),
        )
        subprocess.check_call(java_cmd, cwd=os.path.join(SCRIPT_DIR, "ldbc_driver"))

        # Copy the results to results dir.
        ldbc_results = os.path.join(SCRIPT_DIR, "ldbc_driver", "results", "LDBC-results.json")
        results_dir = os.path.join(SCRIPT_DIR, "results")
        # A fresh checkout may not have the directory yet; copyfile would fail.
        os.makedirs(results_dir, exist_ok=True)

        # File name: [<prefix>-]<database|external>-scale_<N>-LDBC-results.json
        name_parts = []
        if args.result_file_prefix:
            name_parts.append(args.result_file_prefix)
        name_parts.append(args.run_db if args.run_db else "external")
        name_parts.extend(["scale_" + str(args.scale), "LDBC", "results.json"])
        results_copy = os.path.join(results_dir, "-".join(name_parts))
        shutil.copyfile(ldbc_results, results_copy)

        print("Results saved to:", results_copy)

    finally:
        if db:
            db.stop()

    print("Done!")


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
